{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f6df7d4-cbee-41f5-8e4d-e4e2e424b49e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Fish whose calcium image, saliency overlay and cluster overlay are produced below.\n",
    "fish = 11\n",
    "\n",
    "# once the results are saved in the results folder, ensure that BASE_DIR is pointing to where\n",
    "# the saved results are. Also check that the images are in the indicated directory img_dir.\n",
    "BASE_DIR = os.path.abspath(\n",
    "    os.path.join(\n",
    "        os.getcwd(), os.pardir, os.pardir, \"results\", \"experiment_4b\"\n",
    "    )\n",
    ")\n",
    "# Folder holding the plane_0.* image of `fish`. Its parent is expected to contain one\n",
    "# fish{N}_images folder per fish (presumably each with functional_types_df.h5 - confirm).\n",
    "img_dir = os.path.abspath(\n",
    "    os.path.join(\n",
    "        os.getcwd(), os.pardir, os.pardir, \"exp1-4_data\", \"data_prepped_for_models\",\n",
    "        f\"fish{fish}_images\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33510a78-4522-4677-b438-e81022c43ab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import tifffile\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plane0_matches = sorted(glob.glob(os.path.join(img_dir, \"plane_0.*\")))\n",
    "if not plane0_matches:\n",
    "    raise FileNotFoundError(f\"No plane_0.* image found in {img_dir}\")\n",
    "plane0_file = plane0_matches[0]\n",
    "\n",
    "# Min-max normalise to [0, 1], brighten 1.5x and clip so dim structure stays visible.\n",
    "im = tifffile.imread(plane0_file).astype(float)\n",
    "im_range = im.max() - im.min()\n",
    "# a constant image would otherwise divide by zero and turn into NaNs\n",
    "im = (im - im.min()) / im_range if im_range > 0 else np.zeros_like(im)\n",
    "im *= 1.5\n",
    "im = np.clip(im, 0, 1)\n",
    "\n",
    "# Scale bar: 50 µm at 0.6 µm per pixel (same values as the saliency-overlay cell).\n",
    "# Defined here so this cell does not depend on a cell further down the notebook.\n",
    "SCALE_BAR_UM = 50\n",
    "UM_PER_PIXEL = 0.6\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6), dpi=300)\n",
    "ax.imshow(im, cmap=\"viridis\")\n",
    "\n",
    "bar_len_px = SCALE_BAR_UM / UM_PER_PIXEL\n",
    "x0, y0 = 0.02, 0.95\n",
    "\n",
    "# convert axes fraction (x0, 1 - y0), i.e. near the lower-left corner, to data coordinates\n",
    "xdata0, ydata0 = ax.transAxes.transform((x0, 1 - y0))\n",
    "xdata0, ydata0 = ax.transData.inverted().transform((xdata0, ydata0))\n",
    "\n",
    "# draw scale bar\n",
    "ax.hlines(\n",
    "    y=ydata0,\n",
    "    xmin=xdata0,\n",
    "    xmax=xdata0 + bar_len_px,\n",
    "    colors=\"white\",\n",
    "    linewidth=3,\n",
    "    transform=ax.transData,\n",
    "    clip_on=False\n",
    ")\n",
    "ax.text(\n",
    "    xdata0 + bar_len_px / 2,\n",
    "    ydata0 - bar_len_px * 0.3,\n",
    "    f\"{SCALE_BAR_UM} µm\",\n",
    "    color=\"white\",\n",
    "    ha=\"center\",\n",
    "    va=\"top\",\n",
    "    fontsize=12\n",
    ")\n",
    "\n",
    "# Indicator name in the top-left corner, positioned in axes-fraction coordinates.\n",
    "label_kwargs = dict(va=\"top\", ha=\"left\", color=\"white\", weight=\"bold\", fontsize=16)\n",
    "ax.text(0.01, 0.99, \"GCaMP6s\", transform=ax.transAxes, **label_kwargs)\n",
    "ax.axis(\"off\")\n",
    "\n",
    "# Save as a transparent PDF alongside the experiment results.\n",
    "out = os.path.join(BASE_DIR, f\"fish{fish}_calcium.pdf\")\n",
    "fig.savefig(out, transparent=True, bbox_inches=\"tight\")\n",
    "plt.close(fig)\n",
    "print(\"saved\", out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e8f86c-6f25-4223-8f96-069fdf579b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import ast\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tifffile\n",
    "\n",
    "# `fish`, `img_dir` and `BASE_DIR` come from the configuration cell at the top; `fish` is\n",
    "# deliberately not re-assigned here so it cannot drift out of sync with `img_dir`.\n",
    "TOP_K = 30  # number of most-salient plane-0 neurons to mark\n",
    "DOT_MIN, DOT_MAX = 50, 300  # marker-area range (points^2) encoding relative saliency\n",
    "\n",
    "def bright_plane0(f, boost=1.5):\n",
    "    \"\"\"Return fish `f`'s plane_0 image min-max normalised, scaled by `boost`, clipped to [0, 1].\n",
    "\n",
    "    For f == fish the image is read from `img_dir`; other fish are looked up in the\n",
    "    sibling fish{f}_images folder. Raises FileNotFoundError if no plane_0.* file exists.\n",
    "    \"\"\"\n",
    "    folder = img_dir if f == fish else os.path.join(os.path.dirname(img_dir), f\"fish{f}_images\")\n",
    "    matches = sorted(glob.glob(os.path.join(folder, \"plane_0.*\")))\n",
    "    if not matches:\n",
    "        raise FileNotFoundError(f\"No plane_0.* image found in {folder}\")\n",
    "    im = tifffile.imread(matches[0]).astype(float)\n",
    "    span = im.max() - im.min()\n",
    "    # a constant image would otherwise divide by zero\n",
    "    im = (im - im.min()) / span if span > 0 else np.zeros_like(im)\n",
    "    return np.clip(im * boost, 0, 1)\n",
    "\n",
    "def load_saliency(f):\n",
    "    \"\"\"Average every importance.npy found (recursively) under BASE_DIR/fish{f}.\n",
    "\n",
    "    The arrays are combined with np.vstack and averaged over axis 0.\n",
    "    Raises FileNotFoundError when the tree contains no importance.npy.\n",
    "    \"\"\"\n",
    "    fish_root = os.path.join(BASE_DIR, f\"fish{f}\")\n",
    "    print(fish_root)\n",
    "    importance_vectors = []\n",
    "    for dirpath, _subdirs, filenames in os.walk(fish_root):\n",
    "        if \"importance.npy\" in filenames:\n",
    "            importance_vectors.append(np.load(os.path.join(dirpath, \"importance.npy\")))\n",
    "    if len(importance_vectors) == 0:\n",
    "        raise FileNotFoundError(\n",
    "            f\"No importance.npy files under {fish_root}. \"\n",
    "            \"Check the folder structure or BASE_DIR.\"\n",
    "        )\n",
    "    return np.vstack(importance_vectors).mean(0)\n",
    "\n",
    "def plane0_coords(f):\n",
    "    \"\"\"Return (coords, idx) for the plane_0 neurons of fish `f`.\n",
    "\n",
    "    Reads functional_types_df.h5 from the fish{f}_images folder next to `img_dir` (the\n",
    "    same data folder the plane_0 image comes from), keeps rows with plane == \"plane_0\",\n",
    "    stacks their `neur_coords` (string entries parsed with ast.literal_eval) and returns\n",
    "    them with the integer row index, which is used to index the saliency vector.\n",
    "    \"\"\"\n",
    "    h5 = os.path.normpath(os.path.join(\n",
    "        os.path.dirname(img_dir), f\"fish{f}_images\", \"functional_types_df.h5\"\n",
    "    ))\n",
    "    df = pd.read_hdf(h5)\n",
    "    df0 = df[df.plane == \"plane_0\"]\n",
    "\n",
    "    coords = np.vstack(\n",
    "        df0.neur_coords.apply(lambda v: ast.literal_eval(v) if isinstance(v, str) else v)\n",
    "    )\n",
    "    return coords, df0.index.values.astype(int)\n",
    "\n",
    "# Saliency of plane-0 neurons only, selected through the table's row index.\n",
    "sal = load_saliency(fish)\n",
    "coords, idx = plane0_coords(fish)\n",
    "sal0 = sal[idx]\n",
    "\n",
    "# TOP_K most salient neurons (most salient first); marker area grows linearly from\n",
    "# DOT_MIN for the least salient of them to DOT_MAX for the most salient.\n",
    "top = np.argsort(sal0)[-TOP_K:][::-1]\n",
    "arr = sal0[top]\n",
    "lo = arr.min()\n",
    "span = np.ptp(arr) + 1e-9  # epsilon avoids 0/0 when all selected values are equal\n",
    "sizes = DOT_MIN + (arr - lo) / span * (DOT_MAX - DOT_MIN)\n",
    "\n",
    "bg = bright_plane0(fish)\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.imshow(bg, cmap=\"gray\", vmin=0, vmax=1)\n",
    "ax.scatter(coords[top, 0], coords[top, 1], s=sizes, c=\"gray\", edgecolors=\"white\", alpha=.8)\n",
    "ax.text(0.01, 0.99, f\"Fish {fish}\", transform=ax.transAxes,\n",
    "        va=\"top\", ha=\"left\", color=\"white\", weight=\"bold\", fontsize=16)\n",
    "\n",
    "# Scale bar anchored at axes fraction (0.02, 0.05), i.e. near the lower-left corner.\n",
    "SCALE_BAR_UM = 50\n",
    "UM_PER_PIXEL = 0.6\n",
    "bar_len_px = SCALE_BAR_UM / UM_PER_PIXEL\n",
    "x0, y0 = 0.02, 0.95\n",
    "anchor_display = ax.transAxes.transform((x0, 1 - y0))  # axes fraction -> display\n",
    "xdata0, ydata0 = ax.transData.inverted().transform(anchor_display)  # display -> data\n",
    "bar_end = xdata0 + bar_len_px\n",
    "ax.hlines(y=ydata0, xmin=xdata0, xmax=bar_end,\n",
    "          colors=\"white\", linewidth=3, transform=ax.transData, clip_on=False)\n",
    "ax.text(xdata0 + bar_len_px / 2, ydata0 - bar_len_px * 0.3,\n",
    "        f\"{SCALE_BAR_UM} µm\", color=\"white\", ha=\"center\", va=\"top\", fontsize=12)\n",
    "\n",
    "# Transparent figure and axes background, then save as PDF.\n",
    "ax.axis(\"off\")\n",
    "fig.patch.set_alpha(0)\n",
    "ax.patch.set_alpha(0)\n",
    "\n",
    "out = os.path.join(BASE_DIR, f\"fish{fish}_overlay.pdf\")\n",
    "fig.savefig(out, bbox_inches=\"tight\", transparent=True)\n",
    "plt.close(fig)\n",
    "print(\"saved \", out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf7e1dca-1ea4-4581-9b9c-bf286a6288e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import ast\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tifffile\n",
    "\n",
    "# `fish`, `img_dir` and `BASE_DIR` come from the configuration cell at the top; `fish` is\n",
    "# deliberately not re-assigned here so it cannot drift out of sync with `img_dir`.\n",
    "# Marker colour per group; regions other than Pt and Hb are pooled into \"Other\".\n",
    "COLOR_MAP = {\"Pt\": \"orange\", \"Hb\": \"green\", \"Other\": \"purple\"}\n",
    "\n",
    "def bright_plane0(f, boost=1.5):\n",
    "    \"\"\"Return fish `f`'s plane_0 image min-max normalised, scaled by `boost`, clipped to [0, 1].\n",
    "\n",
    "    For f == fish the image is read from `img_dir`; other fish are looked up in the\n",
    "    sibling fish{f}_images folder. Raises FileNotFoundError if no plane_0.* file exists.\n",
    "    \"\"\"\n",
    "    folder = img_dir if f == fish else os.path.join(os.path.dirname(img_dir), f\"fish{f}_images\")\n",
    "    matches = sorted(glob.glob(os.path.join(folder, \"plane_0.*\")))\n",
    "    if not matches:\n",
    "        raise FileNotFoundError(f\"No plane_0.* image found in {folder}\")\n",
    "    im = tifffile.imread(matches[0]).astype(float)\n",
    "    span = im.max() - im.min()\n",
    "    # a constant image would otherwise divide by zero\n",
    "    im = (im - im.min()) / span if span > 0 else np.zeros_like(im)\n",
    "    return np.clip(im * boost, 0, 1)\n",
    "\n",
    "def plane0_table(f):\n",
    "    \"\"\"Return the plane_0 rows of functional_types_df.h5 for fish `f`.\n",
    "\n",
    "    The table is read from the fish{f}_images folder next to `img_dir` (for f == fish\n",
    "    that is `img_dir` itself, the folder the plane_0 image is read from).\n",
    "    \"\"\"\n",
    "    h5 = os.path.normpath(os.path.join(\n",
    "        os.path.dirname(img_dir), f\"fish{f}_images\", \"functional_types_df.h5\"\n",
    "    ))\n",
    "    df = pd.read_hdf(h5)\n",
    "    return df[df.plane == \"plane_0\"]\n",
    "\n",
    "def _parse_coords(v):\n",
    "    # neur_coords may be stored as a string repr of a sequence\n",
    "    return ast.literal_eval(v) if isinstance(v, str) else v\n",
    "\n",
    "def _region_group(r):\n",
    "    # keep Pt and Hb, pool everything else (including missing regions) as \"Other\"\n",
    "    return r if r in (\"Pt\", \"Hb\") else \"Other\"\n",
    "\n",
    "df0 = plane0_table(fish).copy()\n",
    "df0[\"coords\"] = df0.neur_coords.apply(_parse_coords)\n",
    "df0[\"grp\"] = df0.region.fillna(\"unknown\").apply(_region_group)\n",
    "\n",
    "bg = bright_plane0(fish)\n",
    "fig, ax = plt.subplots(figsize=(6, 6), dpi=300)\n",
    "ax.imshow(bg, cmap=\"gray\", vmin=0, vmax=1)\n",
    "\n",
    "# One scatter call per group so each gets its own legend entry; empty groups are skipped.\n",
    "for group_name, group_color in COLOR_MAP.items():\n",
    "    group_coords = df0.loc[df0.grp == group_name, \"coords\"].tolist()\n",
    "    if len(group_coords) == 0:\n",
    "        continue\n",
    "    pts = np.vstack(group_coords)\n",
    "    ax.scatter(pts[:, 0], pts[:, 1], s=40, c=group_color, edgecolors=\"white\",\n",
    "               alpha=.9, label=group_name)\n",
    "\n",
    "ax.legend(loc=\"lower right\")\n",
    "ax.axis(\"off\")\n",
    "fig.patch.set_alpha(0)\n",
    "ax.patch.set_alpha(0)\n",
    "\n",
    "out = os.path.join(BASE_DIR, f\"fish{fish}_clusters_overlay.pdf\")\n",
    "fig.savefig(out, bbox_inches=\"tight\", transparent=True)\n",
    "plt.close(fig)\n",
    "print(\"saved \", out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c6a1e8e-1b6e-483c-9c52-3f7e5b025ee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os ,numpy as np ,pandas as pd ,matplotlib .pyplot as plt \n",
    "from scipy .stats import sem ,ttest_rel \n",
    "\n",
    "# Fish pooled in the across-fish comparison, x-axis group order, and bar colours.\n",
    "fish_list = list(range(11, 14))\n",
    "GROUPS = [\"Pt\", \"Hb\", \"Other\"]\n",
    "COLORS = dict(Pt=\"orange\", Hb=\"green\", Other=\"purple\")\n",
    "\n",
    "def load_saliency(f):\n",
    "    \"\"\"Mean (over axis 0, after np.vstack) of all importance.npy arrays under BASE_DIR/fish{f}.\n",
    "\n",
    "    The directory tree is searched recursively. Raises FileNotFoundError if none are found.\n",
    "    \"\"\"\n",
    "    search_root = os.path.join(BASE_DIR, f\"fish{f}\")\n",
    "    print(search_root)\n",
    "    found = [\n",
    "        np.load(os.path.join(dirpath, \"importance.npy\"))\n",
    "        for dirpath, _dirs, names in os.walk(search_root)\n",
    "        if \"importance.npy\" in names\n",
    "    ]\n",
    "    if found:\n",
    "        return np.vstack(found).mean(0)\n",
    "    raise FileNotFoundError(\n",
    "        f\"No importance.npy files under {search_root}. \"\n",
    "        \"Check the folder structure or BASE_DIR.\"\n",
    "    )\n",
    "\n",
    "def plane0_table(f):\n",
    "    \"\"\"Return the plane_0 rows of functional_types_df.h5 for fish `f`.\n",
    "\n",
    "    The table is read from the fish{f}_images folder next to `img_dir`; going up with\n",
    "    os.pardir after dirname() would leave the data_prepped_for_models folder entirely.\n",
    "    \"\"\"\n",
    "    h5 = os.path.normpath(os.path.join(\n",
    "        os.path.dirname(img_dir), f\"fish{f}_images\", \"functional_types_df.h5\"\n",
    "    ))\n",
    "    df = pd.read_hdf(h5)\n",
    "    return df[df.plane == \"plane_0\"]\n",
    "\n",
    "# Per fish: mean saliency of the plane-0 neurons in each group (Pt, Hb, Other).\n",
    "group_sals = []\n",
    "for f in fish_list:\n",
    "    sal = load_saliency(f)\n",
    "    df = plane0_table(f).copy()\n",
    "    df[\"grp\"] = df.region.fillna(\"unknown\").apply(\n",
    "        lambda r: r if r in (\"Pt\", \"Hb\") else \"Other\"\n",
    "    )\n",
    "    vals = sal[df.index.astype(int)]\n",
    "    s = pd.Series(vals, index=df.grp).groupby(level=0).mean().rename(f\"fish{f}\")\n",
    "    group_sals.append(s)\n",
    "\n",
    "# Rows = groups in GROUPS order, columns = fish; a group absent in a fish is NaN.\n",
    "df_groups = pd.concat(group_sals, axis=1).reindex(GROUPS)\n",
    "means = df_groups.mean(axis=1)\n",
    "# pandas' sem (ddof=1, same as scipy.stats.sem's default) skips NaN like `mean` does,\n",
    "# whereas scipy's default nan_policy would make the whole group's error bar NaN.\n",
    "errs = df_groups.sem(axis=1)\n",
    "\n",
    "# Across-fish mean saliency per group, with SEM error bars.\n",
    "x = np.arange(len(GROUPS))\n",
    "bar_colors = [COLORS[g] for g in GROUPS]\n",
    "fig, ax = plt.subplots(figsize=(5, 2), dpi=300)\n",
    "ax.bar(x, means, yerr=errs, capsize=5, color=bar_colors)\n",
    "ax.set_xticks(x)\n",
    "# hard-coded y-range for these saliency values; widen it if bars or brackets get clipped\n",
    "YLIM = (0.00045, 0.00095)\n",
    "ax.set_ylim(*YLIM)\n",
    "ax.set_xticklabels(GROUPS)\n",
    "ax.set_ylabel(\"Mean saliency\")\n",
    "\n",
    "def draw_bracket(ax, x1, x2, y, h, text):\n",
    "    \"\"\"Draw a bracket from x1 to x2 that starts at height y, rises by h, labelled with `text`.\"\"\"\n",
    "    bracket_top = y + h\n",
    "    ax.plot([x1, x1, x2, x2], [y, bracket_top, bracket_top, y], lw=1.5, c=\"black\")\n",
    "    # small fixed gap (data units) between the bracket and its label\n",
    "    ax.text((x1 + x2) / 2, bracket_top + 0.00001, text, ha=\"center\", va=\"bottom\", fontsize=12)\n",
    "\n",
    "# Paired t-tests across fish for each pair of groups; pairs with p < 0.2 get a bracket.\n",
    "print(\"Paired t-tests:\")\n",
    "offset = errs.max() * 0.1\n",
    "h = errs.max() * 0.1\n",
    "pairs = [(0, 1), (0, 2), (1, 2)]\n",
    "for i, j in pairs:\n",
    "    grp_i, grp_j = GROUPS[i], GROUPS[j]\n",
    "    vals_i, vals_j = df_groups.loc[grp_i], df_groups.loc[grp_j]\n",
    "    t_stat, p_val = ttest_rel(vals_i, vals_j, nan_policy=\"omit\")\n",
    "    print(f\"{grp_i} vs {grp_j}: p = {p_val:.4f}\")\n",
    "    if p_val < 0.2:\n",
    "        # round(p, 1) used to label e.g. p = 0.03 as \"p=0.0\"; use a * tier and 2 decimals\n",
    "        if p_val < 0.001:\n",
    "            star = \"***\"\n",
    "        elif p_val < 0.01:\n",
    "            star = \"**\"\n",
    "        elif p_val < 0.05:\n",
    "            star = \"*\"\n",
    "        else:\n",
    "            star = f\"p={p_val:.2f}\"\n",
    "        # bars sit at positions 0..2 but `means`/`errs` are indexed by group name -> .iloc\n",
    "        y = max(means.iloc[i] + errs.iloc[i], means.iloc[j] + errs.iloc[j]) + offset\n",
    "        draw_bracket(ax, i, j, y, h, star)\n",
    "        offset += h * 2.8\n",
    "\n",
    "fig.tight_layout()\n",
    "out = os.path.join(BASE_DIR, \"cluster_importance_barplot.pdf\")\n",
    "fig.savefig(out, bbox_inches=\"tight\", transparent=True)\n",
    "plt.close(fig)\n",
    "print(\"saved \", out)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
